package zone.suplog.rpc.rabbitmq.proxy;

import com.rabbitmq.client.*;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.amqp.rabbit.connection.Connection;
import org.springframework.util.StringUtils;
import zone.suplog.rpc.spi.ProxyInterceptor;
import zone.suplog.rpc.spi.transfer.RpcRequest;
import zone.suplog.rpc.spi.transfer.RpcResponse;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

public class RabbitProxyInterceptor extends ProxyInterceptor {

    private Log log = LogFactory.getLog(this.getClass());

    /**
     * rabbit连接
     */
    private final Connection connection;

    /**
     * 所路由交换机
     */
    private final String exchange;

    /**
     * 最大连接时间
     */
    private final Long maxConnTimeout;

    /**
     * 系统名称
     */
    private final String applicationName;

    public RabbitProxyInterceptor(Connection connection, Class<?> targetClass, String applicationName, String exchange, Long maxConnTimeout) {
        super.targetClass = targetClass;
        this.connection = connection;
        this.exchange = exchange;
        this.maxConnTimeout = maxConnTimeout;
        this.applicationName = applicationName;
    }

    protected RpcResponse doRequest(RpcRequest rpcRequest) {
        try (Channel channel = Objects.requireNonNull(connection.getDelegate()).createChannel()) {
            String replyQueueName = channel.queueDeclare().getQueue();
            String corrId = UUID.randomUUID().toString();
            AMQP.BasicProperties properties = new AMQP.BasicProperties.Builder()
                    .correlationId(corrId).replyTo(replyQueueName).build();
            channel.confirmSelect();
            channel.queueDeclare(applicationName, false, false, true, null);
            channel.exchangeDeclare(exchange, BuiltinExchangeType.TOPIC);
            channel.queueBind(applicationName, exchange, applicationName);
            log.debug(rpcRequest.toString());
            channel.basicPublish(exchange, rpcRequest.getTargetInterface().getName(), properties, serializable(rpcRequest));
            // 等待消息确认
            channel.waitForConfirmsOrDie(maxConnTimeout);
            return reply(channel, properties, replyQueueName, corrId);
        } catch (Exception e) {
            return RpcResponse.error(e);
        }
    }

    private RpcResponse reply(Channel channel, AMQP.BasicProperties properties, String replyQueueName, String checkCorrId) throws Exception {
        BlockingQueue<RpcResponse> response = new ArrayBlockingQueue<>(1);
        // 服务端消息回复处理队列，如果服务端执行有异常，那么在内部需要抛出异常
        DeliverCallback deliverCallback = (String consumerTag, Delivery message) -> {
            if (properties.getCorrelationId().equals(checkCorrId)) {
                ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(message.getBody());
                ObjectInputStream objectInputStream = new ObjectInputStream(byteArrayInputStream);
                try {
                    response.offer((RpcResponse) objectInputStream.readObject());
                } catch (ClassNotFoundException e) {
                    response.offer(RpcResponse.fail("没有找到服务", e));
                }
            }
        };
        // 消费回复信道消息，即服务端消息回复
        channel.basicConsume(replyQueueName, true, deliverCallback, (consumerTag -> {
        }));
        // 如果服务端没有回复
        RpcResponse take = response.poll(maxConnTimeout, TimeUnit.MILLISECONDS);
        if (StringUtils.hasText(replyQueueName)) {
            channel.queueDelete(replyQueueName);
        }
        if (Objects.isNull(take)) {
            throw new TimeoutException("rpc连接超时");
        }
        return take;
    }
}
